{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45ff14a2",
   "metadata": {},
   "source": [
    "## What可视化\n",
    "\n",
    "2D Grand CAM可视化模块，可视化的结果会输出到当前目录的viz文件夹下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7d998c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "## 获得视频教程\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('What模型可视化')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7fa11b3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n",
    "\n",
    "import monai\n",
    "from glob import glob\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "from datetime import datetime\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "from onekey_algo import get_param_in_cwd\n",
    "\n",
    "# 模型日志位置，如果没有更改默认保存位置，并且模型是当天训练出来的，可以不动这个参数。\n",
    "model_root = get_param_in_cwd('model_root', 'models')\n",
    "# 模型名称\n",
    "model_name = get_param_in_cwd('model_name', 'resnet50')\n",
    "# 设置为使用OKT-crop_max_roi裁剪出来的图片路径，或者你自己训练数据的路径。\n",
    "mydir = get_param_in_cwd('data_pattern') \n",
    "samples = glob(os.path.join(mydir, '*', '*.jpg'))\n",
    "samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a23129e7",
   "metadata": {},
   "source": [
    "## 确定可视化模型\n",
    "\n",
    "通过关键词获取要提取那一层进行可视化。\n",
    "\n",
    "### 支持的模型名称\n",
    "\n",
    "模型名称替换代码中的 `model_name`变量的值。\n",
    "\n",
    "| **模型系列** | **模型名称**                                                 |\n",
    "| ------------ | ------------------------------------------------------------ |\n",
    "| AlexNet      | alexnet                                                      |\n",
    "| VGG          | vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19 |\n",
    "| ResNet       | resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 |\n",
    "| DenseNet     | densenet121, densenet169, densenet201, densenet161           |\n",
    "| Inception    | googlenet, inception_v3                                      |\n",
    "| SqueezeNet   | squeezenet1_0, squeezenet1_1                                 |\n",
    "| ShuffleNetV2 | shufflenet_v2_x2_0, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5 |\n",
    "| MobileNet    | mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small         |\n",
    "| MNASNet      | mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3              |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9a9194",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp2 import extract, init_from_model, init_from_onekey\n",
    "\n",
    "model, transformer, device = init_from_onekey(os.path.join(model_root, model_name, 'viz'))\n",
    "for n, m in model.named_modules():\n",
    "    print('Feature name:', n, \"|| Module:\", m)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e46336a1",
   "metadata": {},
   "source": [
    "## 可视化卷积层\n",
    "\n",
    "`Feature name:` 之后的名称为要可视化的层，例如`layer4.2.conv3`, 一般深度学习特征提取最后一层卷积层\n",
    "\n",
    "** 注意 ** : 可视化的层，一定为带有`conv`的卷积层，而且一般是最后一层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d9ad6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这里的可视化的层需要设置。\n",
    "target_layer = \"layer4.1.conv2\"\n",
    "gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c9708b",
   "metadata": {},
   "source": [
    "## 打印可视化界面"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9eb2b98",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from onekey_algo.datasets.image_loader import default_loader\n",
    "from onekey_algo.custom.components.comp2 import show_cam_on_image\n",
    "import torch\n",
    "import os\n",
    "import random\n",
    "\n",
    "random.shuffle(samples)\n",
    "viz_dir = r'viz'\n",
    "os.makedirs(viz_dir, exist_ok=True)\n",
    "for sample in samples:\n",
    "    img = default_loader(sample)\n",
    "    sample_ = transformer(img)\n",
    "    sample_  = sample_.view(1, *sample_.size()).to(device)\n",
    "    res_cam = gradcam(x=sample_, class_idx=None)\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(20, 10), facecolor='white')\n",
    "    axes[0].imshow(img.resize(sample_.size()[2:]))\n",
    "    axes[0].axis('off')\n",
    "    axes[1].imshow(show_cam_on_image(img.resize(sample_.size()[2:]), -res_cam[0][0].cpu(), use_rgb=True, reverse=False))\n",
    "    axes[1].axis('off')\n",
    "    plt.savefig(f'{viz_dir}/{os.path.basename(sample)}', bbox_inches = 'tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5048302a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
